import re
import jieba
import copy
import numpy as np
from sklearn.model_selection import train_test_split
from gensim.models.word2vec import Word2Vec
from nlu_model.util.pkl_impl import save_pkl, load_pkl
from nlu_model.cls.model_pytorch.model_train import TrainModelPipeline

# Regex matching ASCII punctuation/whitespace plus full-width Chinese punctuation;
# matched spans are stripped from text before jieba segmentation (see cw()).
punctuation = r"[\s+\.\!\/_,$%^*(+\"\'\[\]]+|[+——！，。？、~@#￥%……&*（）：]"
pun = re.compile(punctuation)

# Label -> integer-id map; id 0 is reserved for the fallback class "other".
# Mutated in place by process_label(mode="add") while loading training data.
TITLE_DICT = {"other": 0}

def cw(x):
    """Strip punctuation from *x* and segment it with jieba.

    Returns a list of word tokens for the cleaned sentence.
    """
    cleaned = pun.sub("", x)
    return [token for token in jieba.cut(cleaned)]

def process_label(label, mode="other"):
    """Map a string label to its integer id via the global TITLE_DICT.

    Args:
        label: label string read from a data file.
        mode: "add" registers an unseen label with the next free id
              (mutating TITLE_DICT); any other value maps unseen
              labels to the "other" fallback id.

    Returns:
        The integer id for *label*.
    """
    if label in TITLE_DICT:
        return TITLE_DICT[label]
    if mode == "add":
        # Fix: copy.deepcopy on an int was a no-op; assign the length directly.
        TITLE_DICT[label] = len(TITLE_DICT)
        return TITLE_DICT[label]
    return TITLE_DICT["other"]

def data_loader():
    """Load the news-title train/test splits from tab-separated text files.

    Each line is "<label>\t<sentence>".  Labels seen in the training file
    are registered in TITLE_DICT (mode="add"); unseen test labels fall
    back to the "other" id.  Sentences are segmented via cw().

    Returns:
        (x_train, x_test, y_train, y_test): token lists and label ids.
    """
    x_train, x_test, y_train, y_test = [], [], [], []

    with open("data/cls/news_title/train_file.txt", encoding="utf-8") as f:
        for line in f:
            ll = line.strip().split("\t")
            if len(ll) < 2:
                continue  # skip malformed lines with no sentence column
            label = ll[0]
            # BUG FIX: was '"".join(ll[0])', which re-joined the label's own
            # characters, so the "sentence" was just the label text.
            sentence = ll[1]

            y_train.append(process_label(label, mode="add"))
            x_train.append(cw(sentence))

    with open("data/cls/news_title/test_file.txt", encoding="utf-8") as f:
        for line in f:
            ll = line.strip().split("\t")
            if len(ll) < 2:
                continue
            label = ll[0]
            sentence = ll[1]  # same fix as the train branch

            y_test.append(process_label(label))
            x_test.append(cw(sentence))

    print("train data len: %s" % len(x_train))
    print("test data len: %s" % len(x_test))

    return x_train, x_test, y_train, y_test

def pretrain_w2v(x_train):
    """Train a Word2Vec model on tokenized sentences and build an embedding table.

    Index layout of the returned table:
        0                -> <PAD> (all-zero row)
        1 .. V           -> vocabulary words
        V + 1            -> <UNK> (mean of all trained vectors)

    Args:
        x_train: list of token lists (output of cw()).

    Returns:
        (embedding_weights, imdb_w2v, word2idx_dic)

    Side effects:
        Saves the word->index dict, the embedding matrix and the
        hyper-parameters under ./data/ptm/news_title/.
    """
    N_DIM = 300        # word2vec vector dimensionality
    MIN_COUNT = 5      # minimum frequency for a word to enter the vocabulary
    w2v_EPOCH = 15     # word2vec training epochs
    MAXLEN = 50        # max sentence length (saved for downstream padding)

    # NOTE(review): gensim 3.x API (size=, model.wv.vocab); gensim 4 renamed
    # these to vector_size= / key_to_index — confirm the pinned gensim version.
    imdb_w2v = Word2Vec(size=N_DIM, min_count=MIN_COUNT)
    imdb_w2v.build_vocab(x_train)

    # Train over the corpus (this may take several minutes).
    imdb_w2v.train(x_train, total_examples=len(x_train), epochs=w2v_EPOCH)
    print("model train done")

    # Build the embedding matrix; +2 slots reserved for <PAD> and <UNK>.
    vocab_words = list(imdb_w2v.wv.vocab.keys())
    n_symbols = len(vocab_words) + 2
    embedding_weights = [[0.0] * N_DIM for _ in range(n_symbols)]
    word2idx_dic = {}
    for idx, w in enumerate(vocab_words, start=1):
        embedding_weights[idx] = imdb_w2v.wv[w]  # fix: deprecated imdb_w2v[w]
        word2idx_dic[w] = idx

    # <UNK> vector = mean of all trained word vectors (last row).
    unk_idx = n_symbols - 1
    if word2idx_dic:
        rows = [embedding_weights[word2idx_dic[w]] for w in word2idx_dic]
        embedding_weights[unk_idx] = np.mean(rows, axis=0).tolist()
    word2idx_dic["<UNK>"] = unk_idx

    # Row 0 stays all-zero for <PAD>.
    word2idx_dic["<PAD>"] = 0

    # Save the word->index dictionary.
    save_pkl('./data/ptm/news_title/w2v_word2idx2021072601.pkl', word2idx_dic)
    with open('./data/ptm/news_title/w2v_word2idx2021072601.txt', 'w') as f:
        for item in word2idx_dic:
            f.write("%s\t%s\n" % (item, word2idx_dic[item]))

    # BUG FIX: the embedding matrix and hyper-params were previously written
    # to the SAME paths as the word2idx dict, each save overwriting the last.
    save_pkl("./data/ptm/news_title/w2v_embedding2021072601.pkl", embedding_weights)
    with open("./data/ptm/news_title/w2v_embedding2021072601.txt", "w") as f:
        for row in embedding_weights:
            # BUG FIX: was "%s/n" — a literal "/n", not a newline.
            f.write("%s\n" % ",".join(str(v) for v in row))

    save_pkl("./data/ptm/news_title/w2v_params2021072601.pkl",
             [N_DIM, MIN_COUNT, w2v_EPOCH, MAXLEN])

    return embedding_weights, imdb_w2v, word2idx_dic



def word2idx(source_data, word2idx_dic, padding=-1):
    """Convert tokenized sentences into index sequences.

    Args:
        source_data: list of token lists.
        word2idx_dic: token -> index map; unknown tokens map to the
            last index (len(word2idx_dic) - 1, the <UNK> slot).
        padding: when >= 0, truncate each sequence to this length or
            right-pad it with 0 (<PAD>) up to it; -1 leaves lengths as-is.

    Returns:
        List of numpy int arrays, one per input sentence.
    """
    unk_idx = len(word2idx_dic) - 1
    result_data = []
    for tokens in source_data:
        ids = [word2idx_dic.get(tok, unk_idx) for tok in tokens]
        if padding >= 0:
            # Pad with zeros then cut — covers both the too-long and
            # too-short cases in one expression.
            ids = (ids + [0] * padding)[:padding]
        result_data.append(np.array(ids))
    return result_data
